import os
import torch
import json
import transformers
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
import argparse
import logging
import atexit

from transformers import StoppingCriteria, StoppingCriteriaList

from algo.mello_eval_loop import mello_eval_loop
from algo.gwalk_eval_loop import gwalk_eval_loop
from utils.utils import get_sent_embeddings, get_ent_alias, get_ent_rel_id, process_kg
from utils.utils import get_sent_embeddings
from mquake_dataset import MQUAKE

# Root logging configuration: INFO level and above, timestamped records.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
)
# Module-level logger shared by the helpers and main() in this file.
logger = logging.getLogger(__name__)


class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation when the tail of the sequence matches a stop pattern.

    ``stops`` is a sequence of token ids and ``length`` sets the comparison
    window: the last ``length - 1`` tokens of the first sequence in the batch
    are compared, position by position, with the last ``length - 1`` entries
    of ``stops``.
    """

    def __init__(self, stops=None, length=5):
        super().__init__()
        # A fresh list per instance replaces the former shared mutable default.
        self.stops = [] if stops is None else stops
        self.length = length

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, stops=None, **kwargs):
        """Return True when generation should stop.

        Only ``input_ids[0]`` is inspected; ``scores``, ``stops`` and any extra
        keyword arguments are accepted for interface compatibility and ignored.
        """
        # NOTE(review): the window is ``length - 1`` tokens (range(1, length)),
        # so with ``length == len(stops)`` the first stop token is never
        # checked. Kept as-is to preserve generation behavior - confirm intent.
        window = self.length - 1
        if window <= 0:
            # Nothing to compare: the original loop body never ran and the
            # criterion reported a match.
            return True
        if len(self.stops) == 0:
            # No stop pattern configured, so there is nothing to match.
            return False

        # Never index past the start of a stop list shorter than the window
        # (this previously raised IndexError).
        window = min(window, len(self.stops))

        sequence = input_ids[0]
        if len(sequence) < window:
            # Too few tokens so far to contain the pattern.
            return False

        for offset in range(1, window + 1):
            if sequence[-offset] != self.stops[-offset]:
                return False
        return True


def save_logger_setup(logger_to_save, file_path, delete_duplicate_output_file):
    """Attach a file handler for ``file_path`` to ``logger_to_save``.

    If ``file_path`` already exists it is removed when
    ``delete_duplicate_output_file`` is truthy; otherwise it is left untouched
    and the log goes to the first unused name of the form ``<stem>(0).txt``,
    ``<stem>(1).txt``, ... where ``<stem>`` is ``file_path`` without a trailing
    ``.txt``. If ``file_path`` does not exist, its parent directory is created
    when needed.

    ``save_log_to_file`` is also registered to run at interpreter exit.
    """
    if os.path.exists(file_path):
        if delete_duplicate_output_file:
            os.remove(file_path)
        else:
            # Strip only a trailing ".txt"; splitting on the first ".txt"
            # would truncate paths that contain it in a directory name.
            if file_path.endswith(".txt"):
                stem = file_path[:-len(".txt")]
            else:
                stem = file_path
            # Pick the first free numbered name so earlier logs are never
            # appended to.
            index = 0
            file_path = "%s(%d).txt" % (stem, index)
            while os.path.exists(file_path):
                index += 1
                file_path = "%s(%d).txt" % (stem, index)
    else:
        parent_dir = os.path.dirname(file_path)
        # os.makedirs("") raises, so skip it for bare file names.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    # Create a file handler and set the formatter
    file_handler = logging.FileHandler(file_path)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)

    # Add the file handler to the logger
    logger_to_save.addHandler(file_handler)
    atexit.register(save_log_to_file)


# Function to be called before program exits
def save_log_to_file():
    """Emit an INFO record announcing that the log is being saved."""
    exit_notice = "Saving logging information to a file..."
    logger.info(exit_notice)


def print_arguments(arguments):
    """Render a mapping as ``key: value`` lines joined by newlines.

    Keys and values are converted with ``str``. The result has no trailing
    newline, and an empty mapping yields an empty string.
    """
    return "\n".join(f"{key!s}: {value!s}" for key, value in arguments.items())


def _str2bool(value):
    """Convert a command-line token to a bool for use as an argparse ``type``.

    ``type=bool`` treats every non-empty string (including "False") as True;
    this parser recognises the usual spellings and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected, got %r." % value)


def main():
    """Parse the command line and run the selected evaluation loop.

    The ``--algo`` option selects the branch:

    * ``ice`` / ``ike``: load the model through vLLM and call the matching
      eval loop.
    * ``mello`` / ``gwalk`` / ``pokemqa``: load the model and tokenizer through
      transformers, build the stopping criteria for that model, then call the
      matching eval loop.

    Raises:
        ValueError: if ``--model_name`` or ``--algo`` is not one of the
            supported values.
    """
    parser = argparse.ArgumentParser(description='command line arguments')
    parser.add_argument('--model_name', type=str, help='Model for the edits.')
    parser.add_argument('--device', type=str, default='cuda', help='Cuda or CPU?')
    # file_path and output_dir are concatenated with strings below, so a
    # missing value could only ever end in a TypeError; require them up front.
    parser.add_argument('--file_path', type=str, required=True, help='the main directory path to the files')
    parser.add_argument('--seed', type=int, default=100, help='random seed number')
    parser.add_argument('--output_dir', type=str, required=True, help='output directory')
    parser.add_argument('--delete_duplicate_output_file', type=_str2bool, help='Delete duplicate output file?')
    parser.add_argument('--edit_num', type=int, default=3000, help='number of questions to edit')
    parser.add_argument('--print_prompt', type=_str2bool, default=False, help='print the prompt for debug')
    parser.add_argument('--dataset_name', type=str, default="CF-3k", help='default counterfactual')
    parser.add_argument('--algo', type=str, default='mello')
    parser.add_argument('--masking', type=_str2bool, default=True, help="whether to use masking")
    parser.add_argument('--retriever', default='facebook/contriever-msmarco', type=str, help='The retriever model to use')
    # Without type=float a value given on the command line would stay a str.
    parser.add_argument('--retriever_threshold', default=0.845, type=float, help='Threshold to use for the retriever')
    parser.add_argument('--huggingface_key', default='YOUR_HUGGINGFACE_KEY', type=str, help='Huggingface API key')

    # Parse the command-line arguments.
    args = parser.parse_args()
    model_name = args.model_name
    device = args.device
    file_path = args.file_path
    seed_num = args.seed
    output_dir = args.output_dir
    edit_num = args.edit_num
    retriever_name = args.retriever
    retriever_threshold = args.retriever_threshold
    huggingface_key = args.huggingface_key

    delete_duplicate_output_file = args.delete_duplicate_output_file
    print_prompt = args.print_prompt
    masking = args.masking
    logger.info("===============================\n")
    logger.info("You probably would like to have 'masking==False' when evaluating on the old dataset.")
    logger.info("Masking: %s" % str(masking))
    logger.info("\n===============================")

    ########## DEBUG ############
    # NOTE(review): this overrides --print_prompt unconditionally; kept to
    # preserve current behavior.
    print_prompt = True
    ######################

    dataset_name = args.dataset_name
    algo = args.algo
    name_of_the_run = f"{algo}_{model_name}_{dataset_name}_{edit_num}"
    # file_path and output_dir are used as plain string prefixes, so they are
    # expected to end with a path separator.
    result_file_path = file_path + f"raw_answer_dict_{name_of_the_run}.json"

    save_logger_setup(logger, output_dir + "%s.txt" % name_of_the_run, delete_duplicate_output_file)

    arguments = vars(args)
    logger.info("Args are parsed. And as follow: \n %s" % print_arguments(arguments))

    if algo in ['ice', 'ike']:
        # Imported lazily so the other algorithms do not need vllm installed.
        from transformers import set_seed
        from vllm import LLM, SamplingParams
        from algo.ice_eval_loop import ice_eval_loop
        import random
        from huggingface_hub import login

        set_seed(seed_num)
        random.seed(seed_num)

        # please enter your own huggingface key.
        # login("your_own_huggingface_key")
        login(huggingface_key)

        if model_name == "vicuna-7b":
            full_model_name = "lmsys/vicuna-7b-v1.5"
        elif model_name == "mistral-7b":
            full_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
        elif model_name == "llama3-8b":
            full_model_name = "meta-llama/Meta-Llama-3-8B"
        else:
            raise ValueError("Model <%s> not implemented yet." % model_name)

        # Using VLLM
        model = LLM(model=full_model_name, quantization='fp8')
        sampling_params = SamplingParams(best_of=1, temperature=0, max_tokens=50)

        # data loading.
        mquake_dataset = MQUAKE(dataset_name, file_path, edit_num, seed_num)

        # a set of case_ids from 1,2, ..., to dataset-length
        rand_list = mquake_dataset.get_randlist()

        if algo == 'ice':
            ice_eval_loop(mquake_dataset=mquake_dataset,
                          edited_caseid=rand_list,
                          model=model,
                          sampling_params=sampling_params,
                          result_file_path=result_file_path,
                          masking=masking,
                          dataset_name=dataset_name)
        elif algo == 'ike':
            from algo.ike_eval_loop import ike_eval_loop

            with open(file_path + 'prompts/counterfact.json', 'r') as f:
                demos = json.load(f)

            with open(file_path + 'prompts/corpus_idx.txt', 'r') as fIn:
                # str.split() already discards the trailing newline; slicing
                # off the last character would eat a digit when the final
                # line has no newline.
                corpus_idx = [[int(idx) for idx in line.split()] for line in fIn]

            ike_eval_loop(mquake_dataset=mquake_dataset,
                          edited_caseid=rand_list,
                          model=model,
                          sampling_params=sampling_params,
                          result_file_path=result_file_path,
                          masking=masking,
                          dataset_name=dataset_name,
                          demos=demos,
                          corpus_idx=corpus_idx)

    elif algo in ['mello', 'gwalk', 'pokemqa']:
        # Each model branch loads the tokenizer and model, then defines four
        # stopping criteria from hard-coded token ids for that tokenizer.
        # sc_subq is built but not passed to any eval loop in this function.

        if model_name == "vicuna-7b":
            model_tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5", padding_side='left')
            model = AutoModelForCausalLM.from_pretrained("lmsys/vicuna-7b-v1.5", torch_dtype=torch.float16).to(device)

            # llm generation stopping criteria:
            # retrieve facts:
            sc_facts = StoppingCriteriaList([StoppingCriteriaSub(stops=[8015, 2546, 1490, 2114, 29901])])

            # subquestion:
            sc_subq = StoppingCriteriaList([StoppingCriteriaSub(stops=[13, 4035, 12470, 29901])])

            # Done.
            sc_done = StoppingCriteriaList([StoppingCriteriaSub(stops=[25632, 29889], length=2)])

            # this ends the block:
            sc_end_block = StoppingCriteriaList([StoppingCriteriaSub(stops=[2023, 4515, 1996, 3796])])

        elif model_name == "mistral-7b":  # mistral-7b
            from huggingface_hub import login

            # please enter your own huggingface key.
            # login("your_own_huggingface_key")
            login(huggingface_key)

            model_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
            # NOTE(review): a pad token is added to the tokenizer but the model
            # embeddings are not resized here - confirm this is intended.
            model_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", torch_dtype=torch.float16).to(device)

            # llm generation stopping criteria:
            # retrieve facts:
            sc_facts = StoppingCriteriaList([StoppingCriteriaSub(stops=[8637, 10212, 286, 1639, 28747])])

            # subquestion:
            sc_subq = StoppingCriteriaList([StoppingCriteriaSub(stops=[5078, 17496, 28747])])

            # Done.
            sc_done = StoppingCriteriaList([StoppingCriteriaSub(stops=[384, 538, 28723], length=3)])

            # this ends the block:
            sc_end_block = StoppingCriteriaList([StoppingCriteriaSub(stops=[851, 9675, 272, 2724, 28723])])

        elif model_name == "llama3-8b":  # llama3-8b
            from huggingface_hub import login

            # please enter your own huggingface key.
            # login("your_own_huggingface_key")
            login(huggingface_key)

            model_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
            # NOTE(review): a pad token is added to the tokenizer but the model
            # embeddings are not resized here - confirm this is intended.
            model_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B", torch_dtype=torch.float16).to(device)

            # llm generation stopping criteria:
            # retrieve facts:
            sc_facts = StoppingCriteriaList([StoppingCriteriaSub(stops=[12289, 83712, 2144, 25], length=4)])

            # subquestion:
            sc_subq = StoppingCriteriaList([StoppingCriteriaSub(stops=[3214, 7998, 25], length=3)])

            # Done.
            sc_done = StoppingCriteriaList([StoppingCriteriaSub(stops=[17911, 13, 13], length=2)])

            # this ends the block:
            sc_end_block = StoppingCriteriaList([StoppingCriteriaSub(stops=[2028, 10548, 279, 2565, 13])])

        else:
            raise ValueError("Model <%s> not implemented yet." % model_name)

        # data loading.
        mquake_dataset = MQUAKE(dataset_name, file_path, edit_num, seed_num)

        # a set of case_ids from 1,2, ..., to dataset-length
        rand_list = mquake_dataset.get_randlist()

        if algo == 'mello':
            # retrieve model:
            contriever = AutoModel.from_pretrained(retriever_name).to(device)
            tokenizer = AutoTokenizer.from_pretrained(retriever_name)

            # With zero edits the baseline prompt is used instead.
            if edit_num:
                with open(file_path + 'prompts/MeLLo-prompt.txt', 'r', encoding='utf-8') as f:
                    task_prompt = f.read()
            else:
                with open(file_path + 'prompts/MeLLo-prompt-baseline.txt', 'r', encoding='utf-8') as f:
                    task_prompt = f.read()

            logger.info("Prepare works are Done!")

            mello_eval_loop(mquake_dataset=mquake_dataset, task_prompt=task_prompt,
                            sc_fact=sc_facts, rand_list=rand_list, model=model, model_tokenizer=model_tokenizer,
                            device=device, contriever=contriever, tokenizer=tokenizer, print_prompt=print_prompt,
                            logger=logger, masking=masking, result_file_path=result_file_path)

        elif algo == 'gwalk':
            # retrieve model:
            contriever = AutoModel.from_pretrained(retriever_name).to(device)
            tokenizer = AutoTokenizer.from_pretrained(retriever_name)

            with open(file_path + 'prompts/fill_out_ga_w_blank2.txt', 'r', encoding='utf-8') as f:
                task_prompt = f.read()

            with open(file_path + 'prompts/subq_breakdown.txt', 'r', encoding='utf-8') as f:
                breakdown_prompt = f.read()

            with open(file_path + 'prompts/relation2subq_prompt2.txt', 'r', encoding='utf-8') as f:
                relation2subq_prompt = f.read()

            with open(file_path + 'prompts/subq2rel.json', 'r') as f:
                rel2subq = json.load(f)

            entity2id, id2entity, rel2id, id2rel = get_ent_rel_id(file_path, dataset_name)
            _, kg_s_r_o, rels, ents = process_kg(mquake_dataset.get_dataset(), rand_list, id2entity, id2rel)
            ent2alias, _ = get_ent_alias(mquake_dataset.get_dataset(), entity2id)

            # Only call the embedder when there is something to embed.
            if rels:
                rel_emb = get_sent_embeddings(rels, contriever, tokenizer)
            else:
                rel_emb = []

            if ents:
                ent_emb = get_sent_embeddings(ents, contriever, tokenizer)
            else:
                ent_emb = []

            logger.info("Prepare works are Done!")

            gwalk_eval_loop(dataset=mquake_dataset,
                            task_prompt=task_prompt,
                            sc_facts=sc_facts,
                            model=model,
                            model_tokenizer=model_tokenizer,
                            device=device,
                            rels=rels,
                            rel_emb=rel_emb,
                            contriever=contriever,
                            tokenizer=tokenizer,
                            entity2id=entity2id,
                            ent2alias=ent2alias,
                            rel2id=rel2id,
                            kg_s_r_o=kg_s_r_o,
                            id2entity=id2entity,
                            ent_emb=ent_emb,
                            ents=ents,
                            rand_list=rand_list,
                            print_prompt=print_prompt,
                            breakdown_prompt=breakdown_prompt,
                            sc_end_block=sc_end_block,
                            relation2subq_prompt=relation2subq_prompt,
                            sc_done=sc_done,
                            logger=logger,
                            masking=masking,
                            result_file_path=result_file_path,
                            rel2subq=rel2subq,
                            retriever_threshold=retriever_threshold
                            )

        elif algo == 'pokemqa':
            from algo.pokemqa_eval_loop import pokemqa_eval_loop

            cls_name = file_path + "PokeMQA/detector-checkpoint/detector-ckpt"
            seq_name = file_path + "PokeMQA/detector-checkpoint/dis-ckpt"

            pokemqa_model_name = file_path + "PokeMQA/detector-checkpoint/raw_distilbert"
            tokenizer = transformers.AutoTokenizer.from_pretrained(pokemqa_model_name)

            classifier = transformers.AutoModel.from_pretrained(f"{cls_name}").to(device)
            seq_cls = transformers.AutoModelForSequenceClassification.from_pretrained(f"{seq_name}").to(device)

            with open(file_path + 'PokeMQA/prompts/woKGprompt.txt', 'r') as f:
                task_wokg_prompt = f.read()

            pokemqa_eval_loop(mquake_dataset=mquake_dataset,
                              dataset_name=dataset_name,
                              masking=masking,
                              tokenizer=tokenizer,
                              classifier=classifier,
                              task_wokg_prompt=task_wokg_prompt,
                              seq_cls=seq_cls,
                              result_file_path=result_file_path,
                              device=device,
                              model=model,
                              model_tokenizer=model_tokenizer)
    else:
        raise ValueError(f"Algo {algo} not available.")

    logger.info("Job finished.")


# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
